# -*- coding: utf-8 -*-
"""
Copyright 2021 NXP
All rights reserved.

SPDX-License-Identifier: BSD-3-Clause

author: Kaleb Belete
"""
import os

import torch
import torch.nn as nn


channel = 32  # Default number of feature maps produced by the first conv stage


class NeuralNet(nn.Module):
    """Small VGG-style CNN classifier for 3-channel 32x32 images.

    Three convolutional stages each end in a 2x2 max-pool, reducing a 32x32
    input to a 4x4 spatial map. A single linear layer maps that map to class
    scores.

    Args:
        channels: Feature maps in the first stage; the second and third
            stages use twice this many. Defaults to the module-level
            ``channel`` value, read when the model is constructed.
        num_classes: Number of output scores. Defaults to 10.

    Submodule names (``layer1``..``layer3``, ``fcLayer``) and the layer order
    inside each ``nn.Sequential`` are kept stable, so ``state_dict``
    checkpoints saved from earlier versions of this class still load.
    """

    def __init__(self, channels=None, num_classes=10):
        super().__init__()
        if channels is None:
            # Resolve the global at construction time, as the original did.
            channels = channel
        wide = channels * 2

        self.layer1 = nn.Sequential(
            # Two 3x3 convs; padding=1 keeps the spatial size unchanged.
            nn.Conv2d(3, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x32 -> 16x16
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(channels, wide, kernel_size=3, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16x16 -> 8x8
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 8x8 -> 4x4
        )
        # Classifier head: flattened (wide x 4 x 4) feature maps -> class scores.
        # The 4x4 assumes a 32x32 input after the three 2x pools above.
        self.fcLayer = nn.Linear(4 * 4 * wide, num_classes)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes) for image batch ``x``.

        ``x`` is expected as (batch, 3, H, W) with H = W = 32. ``fcLayer`` is
        sized for the 4x4 map that input size yields.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # Flatten each sample's (C, H, W) activations into one vector,
        # keeping the batch dimension.
        out = out.reshape(out.size(0), -1)
        return self.fcLayer(out)

# Paths for data and models. os.path.join uses the host OS separator instead
# of hard-coded Windows backslashes. The trailing "" component keeps a
# trailing separator, as the original string literals had, so
# "directory + filename" concatenation still works.
model_save_path = os.path.join("..", "Pytorch Models", "")
data_path = os.path.join("..", "dataset", "")  # not used in the code shown here
model_load = os.path.join("..", "Pytorch Models", "")

# Create model instance
model = NeuralNet()
# Load model weights from the .pth checkpoint. map_location="cpu" lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
state_dict = torch.load(
    os.path.join(model_load, "Cifar_retrain.pth"), map_location="cpu"
)
# Load weights to model instance
model.load_state_dict(state_dict)
# Put BatchNorm layers in inference mode (use running statistics) explicitly
# before export.
model.eval()

# Dummy input used to trace the model: (batch size, channels (RGB), img height, img width)
input_shape = torch.randn(1, 3, 32, 32)
# Export to ONNX, naming the graph's input and output nodes
torch.onnx.export(
    model,
    input_shape,
    os.path.join(model_save_path, "Cifar_convert.onnx"),
    input_names=["input"],
    output_names=["output"],
)
